[feat] Faster topk algorithm #3009
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. This behavior is configurable in the review settings, and review commands can be used to manage reviews.
📝 Walkthrough

Adds an exact clustered Top-K implementation: new CUDA headers/kernels and TVM bindings, JIT/build and Python dispatch updates to use clustered kernels, benchmark/test updates (including a CUDA graph & CUPTI timing tweak), and a small cached device utility for shared-memory opt-in.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host Code / User
    participant PyAPI as Python Wrapper\n(topk_clusters_* / top_k)
    participant TVMBind as TVM FFI Binding
    participant Kernel as CUDA Kernel\nfast_topk_cuda_v4
    participant Overflow as Global Overflow Cache
    Host->>PyAPI: call logits, top_k, options
    PyAPI->>PyAPI: allocate indices, values?, cached_overflow
    PyAPI->>TVMBind: call fast_topk_clusters_exact*(..., cached_overflow, ...)
    TVMBind->>TVMBind: validate shapes/dtypes/strides, get stream
    TVMBind->>Kernel: launch specialized kernel on stream
    rect rgba(100,150,200,0.5)
        Kernel->>Kernel: per-block histograms\ncompute threshold_bin
    end
    rect rgba(150,100,200,0.5)
        Kernel->>Kernel: emit >threshold, cache/spill equals\niterate refinement rounds
        Kernel->>Overflow: spill overflow candidates
        Overflow-->>Kernel: supply spilled candidates
    end
    Kernel->>TVMBind: write final indices (and values)
    TVMBind->>PyAPI: return tensors
    PyAPI->>Host: return results
```
```mermaid
sequenceDiagram
    participant Env as Environment
    participant API as top_k()/top_k_*_transform()
    participant Selector as Algorithm Selector
    participant Clusters as Clustered Path
    participant Radix as Radix Path
    participant Host as Return Results
    Env->>API: FLASHINFER_TOPK_ALGO
    API->>Selector: check env var, device, deterministic
    alt select clusters
        Selector->>Clusters: call clustered wrapper
        Clusters->>Host: return indices (+ values, sorted if requested)
    else
        Selector->>Radix: call existing radix multi-CTA path
        Radix->>Host: return indices (+ values)
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Code Review
This pull request introduces a new cluster-based top-k algorithm optimized for Blackwell GPUs (SM 100/103), featuring standard, page table, and ragged transform implementations. The changes include high-performance CUDA kernels using cooperative groups, TVM FFI bindings, and Python API integration with an environment variable toggle for algorithm selection. Feedback focuses on improving code quality by addressing const correctness in the C++ bindings, replacing magic numbers with descriptive constants in the Python wrappers, and enhancing type safety in shared memory calculations.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/bench_topk.py (1)
92-130: ⚠️ Potential issue | 🟠 Major — Don't overwrite the benchmark mode that the caller selected.

Lines 92, 359, and 432 force FLASHINFER_TOPK_ALGO="default" before the baseline timing. That breaks --compare-algorithms: the outer loop sets "multi_cta"/"filtered" right before calling these helpers, but both paths get benchmarked as the same "default" baseline. The extra "clusters" run also is not exception-safe, so a failure there leaves later cases with the wrong env.

Example approach

```diff
+@contextmanager
+def temporary_topk_algo(algo: str | None):
+    previous = os.environ.get("FLASHINFER_TOPK_ALGO")
+    try:
+        if algo is None or algo == "auto":
+            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
+        else:
+            os.environ["FLASHINFER_TOPK_ALGO"] = algo
+        yield
+    finally:
+        if previous is None:
+            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
+        else:
+            os.environ["FLASHINFER_TOPK_ALGO"] = previous

-    set_topk_algo("default")
     fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes(...)
     ...
-    set_topk_algo("clusters")
-    fast_topk_ms = bench_median_ms(lambda: flashinfer.top_k(scores, k))
-    set_topk_algo("auto")
+    with temporary_topk_algo("clusters"):
+        fast_topk_ms = bench_median_ms(lambda: flashinfer.top_k(scores, k))
```

Also applies to: 359-398, 432-469
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 92 - 130, The benchmark currently overwrites the caller's selected FLASHINFER_TOPK_ALGO by calling set_topk_algo("default") and later set_topk_algo("clusters")/set_topk_algo("auto"); update bench_topk.py so it preserves and restores the prior top-k algorithm instead of forcing "default": capture the current algo before any set_topk_algo calls, run bench_flashinfer_modes/bench_median_ms as intended, and restore the original algo in a finally block (or remove the initial set_topk_algo("default") entirely if unnecessary) to make the code exception-safe; reference functions/identifiers to change: set_topk_algo, bench_flashinfer_modes, flashinfer.top_k, bench_median_ms, and the blocks that set "clusters" and "auto".
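The save-and-restore pattern the comment asks for can be sketched as a standalone context manager. This is a minimal, generic sketch (the helper name `temporary_env` is illustrative; the real fix would live in bench_topk.py and target FLASHINFER_TOPK_ALGO):

```python
import os
from contextlib import contextmanager


@contextmanager
def temporary_env(name, value):
    """Set (or unset, when value is None) an env var; restore the prior state on exit."""
    previous = os.environ.get(name)
    try:
        if value is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = value
        yield
    finally:
        # Runs even if the body raised, so later benchmark cases see the original env.
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous


# Example: the variable is restored even when the body raises.
os.environ["FLASHINFER_TOPK_ALGO"] = "multi_cta"
try:
    with temporary_env("FLASHINFER_TOPK_ALGO", "clusters"):
        assert os.environ["FLASHINFER_TOPK_ALGO"] == "clusters"
        raise RuntimeError("simulated benchmark failure")
except RuntimeError:
    pass
assert os.environ["FLASHINFER_TOPK_ALGO"] == "multi_cta"
```

The `finally` block is what makes this exception-safe, which the inline `set_topk_algo(...)` calls are not.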
🧹 Nitpick comments (1)
tests/utils/test_topk.py (1)
2303-2327: Strengthen the exact-path assertions.

topk_clusters_exact is the correctness-preserving path, but this still passes if a row contains a few wrong selections as long as overlap stays above the threshold and the per-row min/max happen to match. Comparing values against gather(logits, indices) and checking the k-th-value threshold row-wise is a much tighter signal without depending on tie ordering.

Example tightening

```diff
 if output_values:
     assert values is not None
     assert values.shape == (batch_size, k)
     assert values.dtype == dtype
     abs_err = 0.125 if dtype == torch.bfloat16 else 1e-5
     rel_err = 0.1 if dtype == torch.bfloat16 else 1e-5
-    torch.testing.assert_close(
-        values.min(dim=-1).values,
-        ref_values.min(dim=-1).values,
-        rtol=rel_err,
-        atol=abs_err,
-    )
-    torch.testing.assert_close(
-        values.max(dim=-1).values,
-        ref_values.max(dim=-1).values,
-        rtol=rel_err,
-        atol=abs_err,
-    )
+    gathered_values = torch.gather(logits, dim=-1, index=indices.long())
+    torch.testing.assert_close(
+        values,
+        gathered_values,
+        rtol=rel_err,
+        atol=abs_err,
+    )
+    assert verify_topk_correctness(logits, values, indices.long(), k)
```
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 2303 - 2327, When output_values is True, strengthen exact-path assertions by verifying that returned values exactly match the selected logits and that each row meets the k-th-value threshold: 1) compute gathered = torch.gather(logits, -1, indices) and assert values.shape/dtype and torch.testing.assert_close(values, gathered) to ensure the exact selected entries match; 2) compute kth = torch.topk(logits, k).values[:, -1] (or equivalently gather the k-th threshold via ref_indices) and assert every value in values per row is >= kth (row-wise) to ensure no value below the k-th threshold was selected; keep the existing accuracy check (compute_topk_accuracy(indices, ref_indices.int(), ...)) afterward.
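The tighter check described above, gather values by the returned indices and compare each row against its k-th-largest threshold, is framework-agnostic. A NumPy sketch (the function name `verify_topk` is illustrative; the real test uses torch):

```python
import numpy as np


def verify_topk(logits, indices, values, k):
    """Tight exact-top-k check: values must equal the gathered logits, and
    every selected value must clear the row's k-th-largest threshold."""
    gathered = np.take_along_axis(logits, indices, axis=-1)
    if not np.allclose(values, gathered):
        return False
    # k-th largest per row, broadcast against the selected values.
    kth = np.sort(logits, axis=-1)[:, -k][:, None]
    return bool(np.all(values >= kth))


logits = np.array([[0.1, 0.9, 0.5, 0.7],
                   [0.3, 0.2, 0.8, 0.4]])
indices = np.array([[1, 3], [2, 3]])  # correct top-2 per row
values = np.take_along_axis(logits, indices, axis=-1)
assert verify_topk(logits, indices, values, 2)

# A wrong selection (index 0 in row 0) fails the threshold check.
bad = np.array([[1, 0], [2, 3]])
assert not verify_topk(logits, bad, np.take_along_axis(logits, bad, axis=-1), 2)
```

Note the check is independent of tie ordering: any permutation of a valid top-k set passes, but any element below the k-th threshold fails.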
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/flashinfer_fast_topk_clusters_binding.cu`:
- Around line 56-73: The dispatch macro DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16
excludes bfloat16 so torch.bfloat16 inputs never reach the launchers; update the
dispatch to include BF16 (or add a separate branch) so the template calls to
launch_fast_topk_clusters_exact (and similar callers at the other spots) are
instantiated with nv_bfloat16 / OrderedBits<nv_bfloat16>. Replace
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 with a macro that also maps the DLPack
dtype for bfloat16 to the correct c_type (nv_bfloat16) or add an explicit
if-case for torch.bfloat16 that calls
launch_fast_topk_clusters_exact<c_type=nv_bfloat16,...> (keeping idx_int64/int
branches and the same argument casts) so BF16 paths reach the kernel.
In `@flashinfer/topk.py`:
- Around line 309-414: The clustered fast_topk kernels are being used
unconditionally; guard topk_clusters_exact, topk_clusters_page_table_transform,
and topk_clusters_ragged_transform (and the other similar helpers noted) behind
the backend/capability gate so we fall back to the radix path unless the device
explicitly supports the clustered backend. Concretely, annotate these APIs with
the `@backend_requirement` decorator and use the provided is_backend_supported() /
is_compute_capability_supported(cc) checks (or the existing helper that checks
FLASHINFER_TOPK_ALGO) to choose the clustered code path only when supported;
otherwise call the original radix implementation (preserve the previous return
types) and do not allocate the large overflow buffers or call fast_topk_* on
unsupported devices. Ensure the same gating is applied to the transform helpers
as well.
In `@include/flashinfer/fast_topk_clusters_exact.cuh`:
- Around line 477-485: In the fast path where seq_len <= TopK you only
initialize output_indices; also write output_values for each slot: inside the
same loop that writes output_indices (use ind_offset and i), set
output_values[ind_offset + i] to the corresponding input value for i < seq_len
(use the same source values array used elsewhere in this function, e.g., values
or input_values) and set output_values[ind_offset + i] to a sentinel for empty
slots (e.g., -INFINITY or numeric_limits<ValueT>::lowest()) when i >= seq_len so
the values tensor is always initialized.
---
Outside diff comments:
In `@benchmarks/bench_topk.py`:
- Around line 92-130: The benchmark currently overwrites the caller's selected
FLASHINFER_TOPK_ALGO by calling set_topk_algo("default") and later
set_topk_algo("clusters")/set_topk_algo("auto"); update bench_topk.py so it
preserves and restores the prior top-k algorithm instead of forcing "default":
capture the current algo before any set_topk_algo calls, run
bench_flashinfer_modes/bench_median_ms as intended, and restore the original
algo in a finally block (or remove the initial set_topk_algo("default") entirely
if unnecessary) to make the code exception-safe; reference functions/identifiers
to change: set_topk_algo, bench_flashinfer_modes, flashinfer.top_k,
bench_median_ms, and the blocks that set "clusters" and "auto".
---
Nitpick comments:
In `@tests/utils/test_topk.py`:
- Around line 2303-2327: When output_values is True, strengthen exact-path
assertions by verifying that returned values exactly match the selected logits
and that each row meets the k-th-value threshold: 1) compute gathered =
torch.gather(logits, -1, indices) and assert values.shape/dtype and
torch.testing.assert_close(values, gathered) to ensure the exact selected
entries match; 2) compute kth = torch.topk(logits, k).values[:, -1] (or
equivalently gather the k-th threshold via ref_indices) and assert every value
in values per row is >= kth (row-wise) to ensure no value below the k-th
threshold was selected; keep the existing accuracy check
(compute_topk_accuracy(indices, ref_indices.int(), ...)) afterward.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e69861f6-01ca-4756-9445-13bbe31545f5
📒 Files selected for processing (7)
- benchmarks/bench_topk.py
- csrc/flashinfer_fast_topk_clusters_binding.cu
- flashinfer/jit/topk.py
- flashinfer/testing/utils.py
- flashinfer/topk.py
- include/flashinfer/fast_topk_clusters_exact.cuh
- tests/utils/test_topk.py
Actionable comments posted: 1
♻️ Duplicate comments (1)
csrc/flashinfer_fast_topk_clusters_binding.cu (1)
56-71: ⚠️ Potential issue | 🔴 Critical — BF16 still never reaches the clustered launchers.

All three entrypoints still dispatch through DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16, so torch.bfloat16 inputs are rejected before launch. top_k, top_k_page_table_transform, and top_k_ragged_transform now route non-deterministic calls here by default, so this is a user-visible regression for advertised BF16 inputs. Please add a BF16-capable dispatch (or an explicit nv_bfloat16 branch) at each site.

Also applies to: 103-111, 143-151
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/flashinfer_fast_topk_clusters_binding.cu` around lines 56 - 71, The dispatch macro DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 currently blocks torch.bfloat16 inputs; update the dispatch at this call site (wrapping launch_fast_topk_clusters_exact) to include BF16 support by either using/adding a DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16 (or equivalent) macro or adding an explicit nv_bfloat16 branch that calls launch_fast_topk_clusters_exact<c_type, ...> with c_type = nv_bfloat16; make the same change for the other two entrypoints that route here (the top_k, top_k_page_table_transform, and top_k_ragged_transform call sites) so BF16 tensors reach the clustered launchers instead of being rejected.
🧹 Nitpick comments (1)
flashinfer/topk.py (1)
346-382: Wrap the new clustered helpers with @flashinfer_api.

topk_clusters_exact, topk_clusters_page_table_transform, and topk_clusters_ragged_transform are top-level helpers in this module, but unlike the existing public top-k APIs they currently bypass the standard logging wrapper. As per coding guidelines: enable API logging in production debugging using the flashinfer_api decorator and environment variables FLASHINFER_LOGLEVEL (0/1/3/5) and FLASHINFER_LOGDEST.

🧩 Suggested fix

```diff
+@flashinfer_api
 def topk_clusters_exact(
     logits, top_k, output_values=False, out_dtype=torch.int32, pdl=False
 ):
     ...

+@flashinfer_api
 def topk_clusters_page_table_transform(
     logits, seq_lens, src_page_table, top_k, pdl=False
 ):
     ...

+@flashinfer_api
 def topk_clusters_ragged_transform(logits, seq_lens, offsets, top_k, pdl=False):
     ...
```

Also applies to: 385-413, 416-442
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 346 - 382, The three new top-level helpers (topk_clusters_exact, topk_clusters_page_table_transform, topk_clusters_ragged_transform) must be wrapped with the flashinfer_api decorator to enable standard logging; add `@flashinfer_api` directly above each def while preserving their signatures, and ensure flashinfer_api is imported/available in this module if not already. Place the decorator immediately above the function definitions (no other changes to arguments/return types), so the functions use the FLASHINFER_LOGLEVEL/FLASHINFER_LOGDEST behavior used by the existing public top-k APIs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/topk.py`:
- Around line 357-363: The per-cluster overflow capacity uses floor division and
must be rounded up to avoid under-allocating when max_model_len % num_clusters
!= 0; replace occurrences where topk_global_overflow = max_model_len //
num_clusters with a ceiling division (e.g., (max_model_len + num_clusters - 1)
// num_clusters) and then allocate overflow_buf (and the analogous buffers in
the other two clustered helpers) using that rounded-up topk_global_overflow so
the kernel’s per-cluster overflow_stride never overruns; update all three
clustered helper sites that define topk_global_overflow and allocate
overflow_buf accordingly.
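The ceiling-division fix this comment asks for is a one-liner; a minimal sketch (the helper name mirrors the comment's wording and is illustrative, not the actual module code):

```python
def overflow_capacity(max_model_len: int, num_clusters: int) -> int:
    """Per-cluster overflow capacity, rounded UP so the total allocation
    never under-provisions when max_model_len % num_clusters != 0."""
    return (max_model_len + num_clusters - 1) // num_clusters


# Floor division would give 2 slots per cluster (8 total) for 10 candidates
# spread over 4 clusters; ceiling division gives 3 (12 total), which is safe.
assert overflow_capacity(10, 4) == 3
assert overflow_capacity(8, 4) == 2   # exact multiples are unchanged
assert overflow_capacity(1, 128) == 1
```

The kernel's per-cluster overflow_stride then indexes into a buffer of size `num_clusters * overflow_capacity(...)` without ever overrunning.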
---
Duplicate comments:
In `@csrc/flashinfer_fast_topk_clusters_binding.cu`:
- Around line 56-71: The dispatch macro DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16
currently blocks torch.bfloat16 inputs; update the dispatch at this call site
(wrapping launch_fast_topk_clusters_exact) to include BF16 support by either
using/adding a DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16 (or equivalent)
macro or adding an explicit nv_bfloat16 branch that calls
launch_fast_topk_clusters_exact<c_type, ...> with c_type = nv_bfloat16; make the
same change for the other two entrypoints that route here (the top_k,
top_k_page_table_transform, and top_k_ragged_transform call sites) so BF16
tensors reach the clustered launchers instead of being rejected.
---
Nitpick comments:
In `@flashinfer/topk.py`:
- Around line 346-382: The three new top-level helpers (topk_clusters_exact,
topk_clusters_page_table_transform, topk_clusters_ragged_transform) must be
wrapped with the flashinfer_api decorator to enable standard logging; add
`@flashinfer_api` directly above each def while preserving their signatures, and
ensure flashinfer_api is imported/available in this module if not already. Place
the decorator immediately above the function definitions (no other changes to
arguments/return types), so the functions use the
FLASHINFER_LOGLEVEL/FLASHINFER_LOGDEST behavior used by the existing public
top-k APIs.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ba20616d-607b-4a87-9e1b-f1b6aace201c
📒 Files selected for processing (5)
- benchmarks/bench_topk.py
- csrc/flashinfer_fast_topk_clusters_binding.cu
- flashinfer/topk.py
- flashinfer/utils.py
- include/flashinfer/fast_topk_clusters_exact.cuh
✅ Files skipped from review due to trivial changes (1)
- include/flashinfer/fast_topk_clusters_exact.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
- benchmarks/bench_topk.py
🧹 Nitpick comments (1)
flashinfer/topk.py (1)
315-316: Remove unused function.

roundup_kbyte is defined but never called anywhere in this file. Consider removing it or documenting its intended future use.

🧹 Proposed fix

```diff
-def roundup_kbyte(x):
-    return (x + 1023) // 1024 * 1024
-
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 315 - 316, The function roundup_kbyte is unused in this module; remove the unused function definition roundup_kbyte from flashinfer/topk.py to eliminate dead code, or if it is intended for future use add a clear docstring and a unit test or TODO comment referencing its intended caller (e.g., any functions that need kilobyte alignment) so its presence is justified—prefer removing it unless you add documentation/tests to show it's required.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@flashinfer/topk.py`:
- Around line 315-316: The function roundup_kbyte is unused in this module;
remove the unused function definition roundup_kbyte from flashinfer/topk.py to
eliminate dead code, or if it is intended for future use add a clear docstring
and a unit test or TODO comment referencing its intended caller (e.g., any
functions that need kilobyte alignment) so its presence is justified—prefer
removing it unless you add documentation/tests to show it's required.
Here are benchmarks for fp16:

And here's one for bf16:

Note that all benchmarks are done on B200.
Great work, and I'll help review this one; using CTA clusters is definitely a good idea.

Hi @Aalanli, can you clarify what "non-deterministic" means here?
Hi @yzh119, the order of indices is not guaranteed to stay consistent across runs. If there are tied elements, the set of indices is not guaranteed to be the same. In practice the tests assert greater than 99% overlap for the f32 and f16 dtypes; the set of values should be the same regardless. The algorithm biases toward earlier elements when there are ties, but due to atomics this is not guaranteed. I think this algorithm has the same properties as the non-deterministic version of the filtered top-k algorithm flashinfer already has. The changes I made to the Python API select this algorithm by default if deterministic=False and sm_100+ is available. Otherwise, if the user selects deterministic=True, they get what flashinfer already has.
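The "greater than 99% overlap" criterion mentioned above reduces to a per-row set-intersection ratio; a sketch (the helper name `topk_overlap` is illustrative, not the actual test utility):

```python
import numpy as np


def topk_overlap(indices_a, indices_b):
    """Mean fraction of shared indices between two batched top-k results.
    Order within a row is ignored, matching a non-deterministic tie policy."""
    ratios = [
        len(set(a.tolist()) & set(b.tolist())) / len(a)
        for a, b in zip(indices_a, indices_b)
    ]
    return float(np.mean(ratios))


a = np.array([[1, 5, 9], [0, 2, 4]])
b = np.array([[9, 5, 1], [0, 2, 7]])  # row 0 reordered, row 1 differs in one slot
assert topk_overlap(a, a) == 1.0
assert abs(topk_overlap(a, b) - (1.0 + 2.0 / 3.0) / 2.0) < 1e-12
```

A reordered-but-identical row scores 1.0, so the metric only penalizes genuinely different index sets (i.e., different tie resolutions), not ordering.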
```cpp
template <typename T>
struct OrderedBits;
```
This code is duplicated with RadixTopKTraits in topk.cuh. Should we create a topk_common.cuh file to reuse the struct OrderedBits in both topk.cuh and fast_topk_clusters_exact.cuh?
Yeah, that sounds good.
Hi @jiangyinzuo, I moved the common stuff into topk_common.cuh.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/fast_topk_clusters_exact.cuh`:
- Around line 561-583: The helper launch_topk_cluster_kernel currently calls
cudaFuncSetAttribute and cudaLaunchKernelExC without checking their return
values; update launch_topk_cluster_kernel to capture and handle cudaError_t
results from cudaFuncSetAttribute (both calls) and from cudaLaunchKernelExC,
e.g., check the returned error, and on failure either assert in debug builds or
propagate/return the error (or log it) so failures on non‑SM90 hardware are
visible; reference the cudaFuncSetAttribute calls near the top of
launch_topk_cluster_kernel and the final cudaLaunchKernelExC call when adding
these checks.
In `@include/flashinfer/topk_common.cuh`:
- Around line 4-8: The header is missing CUDA headers required for types and
intrinsics used: add the appropriate CUDA includes so symbols like
cuda::std::numeric_limits<float>::infinity(), half and
__half_as_ushort/__ushort_as_half, and
nv_bfloat16/__bfloat16_as_ushort/__ushort_as_bfloat16 are defined; specifically
include <cuda/std/limits> for cuda::std::numeric_limits, <cuda_fp16.h> for half
and half intrinsics, and <cuda_bf16.h> for nv_bfloat16 and its intrinsics
(update include list near the top of topk_common.cuh where other std headers are
listed).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6a813e72-279a-450e-bc99-0da55c5ee316
📒 Files selected for processing (4)
- csrc/flashinfer_fast_topk_clusters_binding.cu
- include/flashinfer/fast_topk_clusters_exact.cuh
- include/flashinfer/topk.cuh
- include/flashinfer/topk_common.cuh
✅ Files skipped from review due to trivial changes (1)
- csrc/flashinfer_fast_topk_clusters_binding.cu
/bot run

@Aalanli the pre-commit format check is failing: https://github.com/flashinfer-ai/flashinfer/actions/runs/24349849138/job/71303983434?pr=3009 Could you please rerun pre-commit and push? Also, could you either address or resolve the AI code review comments above?

Hi @kahyunnam, thanks for taking a look. I addressed the AI comments, some build issues, and the pre-commit failures.
♻️ Duplicate comments (3)
include/flashinfer/fast_topk_clusters_exact.cuh (2)
414-424: ⚠️ Potential issue | 🟡 Minor — Initialize padded output_values slots too.

The valid entries are fixed now, but when TopK > seq_len the else branch still leaves output_values[ind_offset + i] untouched while output_indices is set to -1. That returns garbage in the padded tail of the row.

💡 Minimal fix

```diff
 } else {
   output_indices[ind_offset + i] = static_cast<IdxT>(-1);
+  if (output_values != nullptr) {
+    output_values[ind_offset + i] = RadixTopKTraits<T>::NegInf();
+  }
 }
```
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/fast_topk_clusters_exact.cuh` around lines 414 - 424, The padded output_values slots aren't initialized when TopK > seq_len: in the branch inside fast_topk_clusters_exact.cuh where you set output_indices[ind_offset + i] = -1 for i >= seq_len, also set output_values[ind_offset + i] to a defined sentinel (e.g., zero or -INF consistent with your API) so the padded tail doesn't return garbage; update the same loop that uses output_values and logits/logit_offset to initialize output_values in that else branch.
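The padding convention the fix establishes — -1 for unused index slots and a negative-infinity sentinel for their values — can be illustrated on the host side. A NumPy sketch (the function name is illustrative, not the kernel's API):

```python
import numpy as np


def topk_padded(row, k, sentinel=-np.inf):
    """Top-k over one row; when k > len(row), pad indices with -1 and
    values with a sentinel so the output tensors are always initialized."""
    m = min(k, row.size)
    order = np.argsort(row)[::-1][:m]  # descending top-m positions
    indices = np.full(k, -1, dtype=np.int64)
    values = np.full(k, sentinel, dtype=np.float64)
    indices[:m] = order
    values[:m] = row[order]
    return indices, values


idx, val = topk_padded(np.array([3.0, 1.0, 2.0]), k=5)
assert idx.tolist() == [0, 2, 1, -1, -1]
assert val[:3].tolist() == [3.0, 2.0, 1.0]
assert np.all(np.isneginf(val[3:]))  # padded tail is defined, not garbage
```

Without the sentinel write, the tail of `val` would hold whatever was in the buffer before the call — exactly the bug the review comment flags.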
552-574: ⚠️ Potential issue | 🟡 Minor — Surface CUDA API failures from the launch helper.

cudaFuncSetAttribute and cudaLaunchKernelExC both return cudaError_t, but this helper ignores them. If the cluster launch or shared-memory opt-in fails, the later cudaGetLastError() in the binding won't pinpoint the failing call and can miss attribute failures entirely.

🔍 Minimal debug-visible handling

```diff
 inline void launch_topk_cluster_kernel(void* kernel, void** args, int grid_dim, int smem_bytes,
                                        int num_clusters, bool pdl_enabled, cudaStream_t stream) {
-  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, MAX_SMEM_CARVEOUT);
-  cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+  auto err =
+      cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, MAX_SMEM_CARVEOUT);
+  assert(err == cudaSuccess);
+  err = cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+  assert(err == cudaSuccess);

   cudaLaunchConfig_t config;
   config.numAttrs = 0;
@@
   config.dynamicSmemBytes = smem_bytes;
   config.gridDim = grid_dim;
   config.stream = stream;
   config.attrs = attribute;
-  cudaLaunchKernelExC(&config, kernel, args);
+  err = cudaLaunchKernelExC(&config, kernel, args);
+  assert(err == cudaSuccess);
 }
```
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/fast_topk_clusters_exact.cuh` around lines 552 - 574, The helper launch_topk_cluster_kernel currently ignores return values from cudaFuncSetAttribute and cudaLaunchKernelExC so failures are hidden; fix it by making launch_topk_cluster_kernel return cudaError_t (instead of void), check the cudaError_t result after each cudaFuncSetAttribute and after cudaLaunchKernelExC, and immediately return the error on failure (or propagate it) so callers can handle/log it; update callers to handle the returned cudaError_t and propagate or log it accordingly. Use the existing symbols cudaFuncSetAttribute, cudaLaunchKernelExC, and launch_topk_cluster_kernel to locate and change the code.

include/flashinfer/topk_common.cuh (1)
4-7: ⚠️ Potential issue | 🟠 Major — Make topk_common.cuh self-contained.

RadixTopKTraits now names cuda::std::numeric_limits, half, and nv_bfloat16, but this header still only includes libc headers. That leaves it dependent on include order, and include/flashinfer/fast_topk_clusters_exact.cuh currently only brings in <cuda_fp16.h> before including it, so cuda::std and nv_bfloat16 are still unresolved here.

🔧 Minimal fix

```diff
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda/std/limits>
+
 #include <cstdint>
 #include <cstdlib>
 #include <numeric>
 #include <type_traits>
```
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk_common.cuh` around lines 4 - 7, RadixTopKTraits in topk_common.cuh relies on cuda::std::numeric_limits and the types half and nv_bfloat16 but the header only includes libc headers; make topk_common.cuh self-contained by adding the necessary includes (e.g., <limits> to provide numeric_limits, <cuda_fp16.h> to provide half, and <cuda_bf16.h> to provide nv_bfloat16) so RadixTopKTraits compiles independently of include order.
🧹 Nitpick comments (1)
include/flashinfer/fast_topk_clusters_exact.cuh (1)
209-215: Document why this branch spills instead of taking the cheaper alternatives.

The current comment explains the mechanism, but not why the exact path chooses global spill over dropping overflowed candidates or doing another pass. One sentence on that trade-off would make future tuning much safer. As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/fast_topk_clusters_exact.cuh` around lines 209 - 215, Add a one-sentence justification above the spill branch explaining why the code chooses to spill to the per-CTA global overflow cache (using s_cached_overflow_count, overflow_stride, get_cached_overflow and writing PackedCachedData) instead of cheaper alternatives like dropping overflowed candidates or performing another pass; mention the trade-off: preserving candidate correctness / avoiding additional kernel passes at the cost of a rare global spill and minimal memory overhead, and note that alternatives were considered but rejected due to increased error rate or extra synchronization/latency.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@include/flashinfer/fast_topk_clusters_exact.cuh`:
- Around line 414-424: The padded output_values slots aren't initialized when
TopK > seq_len: in the branch inside fast_topk_clusters_exact.cuh where you set
output_indices[ind_offset + i] = -1 for i >= seq_len, also set
output_values[ind_offset + i] to a defined sentinel (e.g., zero or -INF
consistent with your API) so the padded tail doesn't return garbage; update the
same loop that uses output_values and logits/logit_offset to initialize
output_values in that else branch.
- Around line 552-574: The helper launch_topk_cluster_kernel currently ignores
return values from cudaFuncSetAttribute and cudaLaunchKernelExC so failures are
hidden; fix it by making launch_topk_cluster_kernel return cudaError_t (instead
of void), check the cudaError_t result after each cudaFuncSetAttribute and after
cudaLaunchKernelExC, and immediately return the error on failure (or propagate
it) so callers can handle/log it; update callers to handle the returned
cudaError_t and propagate or log it accordingly. Use the existing symbols
cudaFuncSetAttribute, cudaLaunchKernelExC, and launch_topk_cluster_kernel to
locate and change the code.
In `@include/flashinfer/topk_common.cuh`:
- Around line 4-7: RadixTopKTraits in topk_common.cuh relies on
cuda::std::numeric_limits and the types half and nv_bfloat16 but the header only
includes libc headers; make topk_common.cuh self-contained by adding the
necessary includes (e.g., <limits> to provide numeric_limits, <cuda_fp16.h> to
provide half, and <cuda_bf16.h> to provide nv_bfloat16) so RadixTopKTraits
compiles independently of include order.
---
Nitpick comments:
In `@include/flashinfer/fast_topk_clusters_exact.cuh`:
- Around line 209-215: Add a one-sentence justification above the spill branch
explaining why the code chooses to spill to the per-CTA global overflow cache
(using s_cached_overflow_count, overflow_stride, get_cached_overflow and writing
PackedCachedData) instead of cheaper alternatives like dropping overflowed
candidates or performing another pass; mention the trade-off: preserving
candidate correctness / avoiding additional kernel passes at the cost of a rare
global spill and minimal memory overhead, and note that alternatives were
considered but rejected due to increased error rate or extra
synchronization/latency.
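The trade-off the nitpick asks to document — spill overflow candidates rather than drop them, preserving exactness at the cost of a rare extra pass — can be sketched at a high level. This is a pure-Python stand-in for the per-CTA local buffer plus global overflow cache, not the kernel itself:

```python
def emit_with_overflow(candidates, local_capacity):
    """Partition candidates into a fixed-size local buffer and an overflow
    list. Spilling (instead of dropping) keeps the result exact at the cost
    of an occasional extra pass over the spilled candidates."""
    local, overflow = [], []
    for c in candidates:
        (local if len(local) < local_capacity else overflow).append(c)
    return local, overflow


local, overflow = emit_with_overflow(list(range(10)), local_capacity=8)
assert len(local) == 8 and overflow == [8, 9]
# Exactness: nothing is lost, merely deferred to a refinement round.
assert sorted(local + overflow) == list(range(10))
```

Dropping the tail instead would be cheaper but turns the kernel into an approximate top-k; a second full pass would restore exactness but costs latency on every call rather than only on overflow.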
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a6b88abe-9adf-4b60-9a41-e20f14433ba8
📒 Files selected for processing (2)
- include/flashinfer/fast_topk_clusters_exact.cuh
- include/flashinfer/topk_common.cuh
@Aalanli is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

/bot run

Hi @kahyunnam, I think the build failure is not due to the changes in this PR; do you know what the issue is?
@Aalanli I think the build is just flaky; we can go ahead and merge. The "Test Results Summary" step is still ongoing, but I've enabled automerge.
Head branch was pushed to by a user without write access
Hi @kahyunnam, I found and fixed an edge case that manifested only some of the time in the bfloat16 case (the last bin could be distributed in such a way that the threshold bin contains some top-k values); performance is still competitive. The other tests that previously failed were due to Eg:
📌 Description
This PR implements a faster top-k algorithm that uses the SM90+ CTA clusters feature. This is a non-deterministic algorithm, but it does not drop indices; instead it overflows to global memory. Benchmark results show that it is faster than both the multi-CTA top-k algorithm and the filtering algorithm overall. The cases where it is slower are when overflow happens too often.
Note: "Speedup" is the speedup of flashinfer vs. torch, while "Speedup Clusters vs. Default" is the speedup of this kernel over flashinfer's default.
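The histogram/threshold-bin structure this PR describes (per-block histograms → threshold bin → exact refinement of tie candidates, with spill instead of drop) can be modeled sequentially. This is an illustrative NumPy sketch of the algorithmic idea, not the CUDA kernel:

```python
import numpy as np


def histogram_topk(values, k, num_bins=256):
    """Exact top-k indices: one histogram pass finds a threshold bin; only
    candidates inside that bin (the kernel's tie/overflow set) need refinement."""
    lo, hi = float(values.min()), float(values.max())
    scale = num_bins / (hi - lo) if hi > lo else 0.0
    bin_ids = np.minimum(((values - lo) * scale).astype(np.int64), num_bins - 1)
    hist = np.bincount(bin_ids, minlength=num_bins)
    # cum[b] = how many elements fall in bin b or above.
    cum = np.cumsum(hist[::-1])[::-1]
    threshold_bin = int(np.max(np.nonzero(cum >= k)[0]))
    sure = np.nonzero(bin_ids > threshold_bin)[0]   # definitely in the top-k
    ties = np.nonzero(bin_ids == threshold_bin)[0]  # refine only these
    refined = ties[np.argsort(values[ties])[::-1][: k - sure.size]]
    return np.concatenate([sure, refined])


values = np.array([5.0, 1.0, 9.0, 3.0, 7.0, 2.0, 8.0])
idx = histogram_topk(values, k=3)
assert sorted(values[idx].tolist()) == [7.0, 8.0, 9.0]
```

The key property mirrored from the kernel: everything strictly above the threshold bin is emitted immediately, and only the (usually small) threshold-bin population needs further work — which is also where the bfloat16 edge case mentioned above lived, since a coarse bin can hold both top-k and non-top-k values.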
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (e.g., unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Benchmarks
Tests
Chores
Style